General plugin mechanism #45355

Merged
merged 1 commit into from
Sep 15, 2022

@weishengying weishengying commented Aug 23, 2022

PR types

New features

PR changes

Others

Describe

Loading mechanism for the generic plugin and for custom-op plugins

Key points of the design:
image

The current Paddle-TRT workflow is shown below:
image

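The main job of this stage is to convert each Paddle op into a TRT layer through an xxOpConverter object. The new design adds two converters (creaters) so that the following kinds of ops can enter TRT:
Type 1: ops supported by the generic plugin;
Type 2: user-defined custom ops.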

image

Two new converters are added in generic_and_custom_plugin_creater.cc:
generic_plugin_creater: creates the generic plugin for the ops it supports;
custom_plugin_creater: creates the user-defined plugin for a custom op.

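OpTeller has been adjusted.

Together with the many existing Paddle op converters, the two new converters bring the number of converter kinds to three. OpTeller decides which kind a given Paddle op uses.

Previously, OpTeller only determined whether TRT supports a given Paddle op, yet many of the if checks lived in op_converter.h. That is not a good fit: op_converter.h should only look up the converter by op name, while the boundary-condition checks that decide whether an op can be converted belong in op_teller.cc.

With three kinds of converter, OpTeller must now report not only whether TRT supports a Paddle op but also which kind of converter to use.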

enum class OpConverterType {
  Default = 0,
  GenericPluginCreater,
  CustomPluginCreater
};

Default means the framework's existing xx_op_converter is used.
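As a rough illustration of the two-part answer described above (supported or not, and which converter), here is a minimal sketch. `TellerSketch`, its three name sets, and the lookup order are hypothetical stand-ins, not the PR's actual OpTeller, which inspects the op desc and its own registries.

```cpp
#include <cassert>
#include <set>
#include <string>

// Mirrors the enum from the PR description.
enum class OpConverterType {
  Default = 0,
  GenericPluginCreater,
  CustomPluginCreater
};

// Hypothetical stand-in for the state OpTeller consults.
struct TellerSketch {
  std::set<std::string> builtin_ops;  // ops with an existing xx_op_converter
  std::set<std::string> generic_ops;  // ops with a symbolic shape function
  std::set<std::string> custom_ops;   // ops with a user-registered plugin

  // Returns true if the op can enter TRT and writes the converter kind to
  // *type. The built-in, generic, custom order is an assumption.
  bool Tell(const std::string& op_type, OpConverterType* type) const {
    assert(type != nullptr);
    if (builtin_ops.count(op_type) > 0) {
      *type = OpConverterType::Default;
      return true;
    }
    if (generic_ops.count(op_type) > 0) {
      *type = OpConverterType::GenericPluginCreater;
      return true;
    }
    if (custom_ops.count(op_type) > 0) {
      *type = OpConverterType::CustomPluginCreater;
      return true;
    }
    return false;  // op stays outside TRT; *type is left untouched
  }
};
```

A caller would first check the boolean result and only then dispatch on the reported `OpConverterType`.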

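Testing

1. Testing generic_plugin_creater

Adding a symbolic shape inference function to dynamic_shape_infermeta.cc makes the corresponding op supported by the generic plugin.
The existing auto_scan unit tests can then exercise the generic plugin's support for that op.
For example, a symbolic shape inference function for the gather_nd op was added to dynamic_shape_infermeta.cc, so the test_trt_convert_gather_nd.py unit test now uses the generic plugin.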
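To make the shape rule concrete, the sketch below computes the gather_nd output shape with plain integers. It is only an illustration: the function added in this PR works on symbolic nvinfer1::DimsExprs through an IExprBuilder, and `GatherNdShapeSketch` is a hypothetical name. The rule itself is out = index.shape[:-1] + x.shape[index.shape[-1]:], which assumes the last dimension of index is known.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Output shape of gather_nd, using concrete integers in place of the
// symbolic dimension expressions used by the real inference function.
std::vector<int64_t> GatherNdShapeSketch(
    const std::vector<int64_t>& x_dims,
    const std::vector<int64_t>& index_dims) {
  assert(!index_dims.empty());
  // k leading dimensions of x are consumed by each index tuple.
  const int64_t k = index_dims.back();
  assert(k >= 0 && static_cast<std::size_t>(k) <= x_dims.size());
  // Keep every index dimension except the last one...
  std::vector<int64_t> out(index_dims.begin(), index_dims.end() - 1);
  // ...then append the dimensions of x that were not indexed.
  out.insert(out.end(), x_dims.begin() + k, x_dims.end());
  return out;
}
```

For x of shape [2, 3, 4] and index of shape [5, 2], each of the 5 index tuples selects a slice of shape [4], so the result has shape [5, 4].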

2. Testing custom_plugin_creater

test_custom_plugin_creater.cc
test_custom_op_plugin.cc

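These files test whether custom_plugin_creater can load the plugin of a user-defined op. Specifically, they check that custom_plugin_creater correctly loads the custom op, reads its attributes, obtains the right plugin_creator from IPluginRegistry, passes the op's attributes to that plugin_creator, and then successfully creates the custom op's plugin.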

Convention:
The static-shape plugin of a user-defined op has a name ending in “_paddle_trt_plugin”; the dynamic-shape plugin has a name ending in “_paddle_trt_dynamic_plugin”.
test_custom_op_plugin.cc defines one static-shape plugin and one dynamic-shape plugin for custom_op.
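The naming convention can be sketched as a small helper. The two suffixes come from the description above; the helper `CustomPluginName` and the choice of the op type as the prefix are assumptions made for illustration, not confirmed by the PR text.

```cpp
#include <cassert>
#include <string>

// Suffixes stated in the PR description.
const char kStaticSuffix[] = "_paddle_trt_plugin";
const char kDynamicSuffix[] = "_paddle_trt_dynamic_plugin";

// Hypothetical helper: builds the plugin name expected for a custom op.
// Using the op type as the prefix is an assumption.
std::string CustomPluginName(const std::string& op_type,
                             bool with_dynamic_shape) {
  assert(!op_type.empty());
  return op_type + (with_dynamic_shape ? kDynamicSuffix : kStaticSuffix);
}
```

Under these assumptions, an op called custom_op would map to custom_op_paddle_trt_plugin in static-shape mode and custom_op_paddle_trt_dynamic_plugin in dynamic-shape mode.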

These two plugins inherit directly from public nvinfer1::IPluginV2 and public nvinfer1::IPluginV2DynamicExt, respectively.

Neither plugin contains real computation logic. The main interface implemented is:

nvinfer1::IPluginV2* createPlugin(
      const char* name,
      const nvinfer1::PluginFieldCollection* fc) noexcept override;

This API receives the custom op's attributes and then creates the plugin object.

paddle-bot bot commented Aug 23, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot bot commented Aug 23, 2022

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@weishengying weishengying reopened this Aug 23, 2022
@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Aug 23, 2022
@PaddlePaddle PaddlePaddle unlocked this conversation Aug 23, 2022
Contributor

@qingqing01 qingqing01 left a comment

add unit testing

plugindata.data = &value;
} else {
CHECK(false) << "not incompleted";
}
Contributor

Why not support float and other dtypes?

Contributor Author

@weishengying weishengying Sep 8, 2022

The remaining types have been added.

plugindatas.push_back(plugindata);
}

nvinfer1::PluginFieldCollection pluginFC{(int32_t)plugindatas.size(),
Contributor

“pluginFC” should be lowercase with underscores; see https://google.github.io/styleguide/cppguide.html#Variable_Names

Contributor Author

Fixed.

}

auto creator =
GetPluginRegistry()->getPluginCreator(op_desc.Type().c_str(), "1");
Contributor

Add a comment explaining why the plugin version is set to “1”.

auto *var = block_desc.FindVar(arg_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound("no variable called %s in block.",
Contributor

no variable called -> There is no variable called

class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
Contributor

Remove line 21 to line 28

const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc_);
Contributor

op_desc_ -> op_desc

return false;
}
return true;
}
Contributor

GenericPlugin also restricts the input and output types of an op. Should a check for that be added here?

Contributor Author

That check is op-specific. As more ops gain generic plugin support, checks will be added according to each op's specifics.

auto* attr_ptr = attr_reader.GetAttr(attr_name);
switch (attr_defs[k].type_index) {
case phi::AttributeType::SCALAR:
if (attr_ptr) {
Contributor

Should the if (attr_ptr) check be hoisted to before line 49? Every case below repeats it.

Contributor Author

Fixed.

}

template <typename T>
inline std::string vectorToStr(const std::vector<T>& dims) {
Contributor

vectorToStr -> VectorToStr

outputs_data_type_ = outputs_data_type;
}

GenericPlugin::GenericPlugin(void const* serialData, size_t serialLength) {
Contributor

serialData -> serial_data, serialLength -> serial_len

Contributor

@hp03 hp03 left a comment

LGTM

if ((desc.HasAttr("namescope") &&
PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) ==
"/skip_quant_2/") ||
desc.HasAttr("skip_quant"))
Contributor

Isn't "skip_quant_2" too hard-coded here?

Contributor Author

This is pre-existing code. I moved these if checks from op_converter.h into OpTeller, which is why git attributes the change to me. The logic itself is unchanged.

// only consider dynamic_shape mode
if (!with_dynamic_shape) {
return false;
}
Contributor

Why is only dynamic_shape mode considered here?

break;

default:
CHECK(false) << "no OpConverter for optype " << op_desc.Type();
Contributor

Log messages need to be grammatically correct.

op_desc_ = std::move(framework::OpDesc(proto_op_desc_, nullptr));
// op_meta_data_
proto_op_desc_.SerializeToString(&op_meta_data_);
// inputs_data_type_ and outputs_data_type_
Contributor

These // comments look pointless, don't they?

Contributor Author

Fixed.

op_meta_data_ = std::move(op_meta_data);
// proto_op_desc_
proto_op_desc_.ParseFromString(op_meta_data_);
// op_desc_
Contributor

Same as above.

yield self.create_inference_config(), (0, 4), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 4), 1e-5
yield self.create_inference_config(), (1, 3), 1e-5
Contributor

What is the reason for removing the Half test here?

Contributor Author

@weishengying weishengying Sep 8, 2022

It has been restored.

qingqing01 previously approved these changes Sep 9, 2022
for (auto &attr_name : op_attrs_names) {
nvinfer1::PluginField plugindata;
plugindata.name = attr_name.c_str();
if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) {
Contributor

It looks like this could be folded with a macro. The same applies to the other switch cases.

namespace inference {
namespace tensorrt {

nvinfer1::DimsExprs GatherNdInferMeta(
Contributor

Do all OpInferMeta functions have to be written in this file?

Contributor Author

Yes, they are collected in this file.

qingqing01 previously approved these changes Sep 9, 2022
Comment on lines 452 to 453
free(dense_tensor_inputs);
free(dense_tensor_outputs);
Contributor

Can free actually release these here?

Contributor Author

Yes. Only the allocated vector object itself needs to be released; the vector releases its elements as part of being released.

Contributor

The destructor will not be called correctly, so there is a risk of a memory leak.

Contributor Author

Changed to delete.
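The point raised in this thread can be illustrated with a minimal, self-contained sketch. It assumes the pointers in question refer to std::vector objects allocated with new, as the reply above suggests; `Tracked` and `DestroyedByDelete` are hypothetical names. delete runs the vector's destructor, which destroys every element and frees the buffer, whereas free would only hand back raw memory without running any destructor (and calling it on memory obtained from new is undefined behavior).

```cpp
#include <cassert>
#include <vector>

// Counts destructor runs so the effect of delete is observable.
struct Tracked {
  static int destroyed;
  ~Tracked() { ++destroyed; }
};
int Tracked::destroyed = 0;

// Allocates a vector with new, releases it with delete, and returns how
// many element destructors ran as a result.
int DestroyedByDelete(int n) {
  assert(n >= 0);
  Tracked::destroyed = 0;
  auto* items = new std::vector<Tracked>(n);
  delete items;  // runs ~vector, which destroys each element
  return Tracked::destroyed;
}
```

With a conforming C++11 or later compiler, releasing a vector of n elements this way runs exactly n element destructors.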

qingqing01 previously approved these changes Sep 14, 2022
zhangjun previously approved these changes Sep 14, 2022
Contributor

@zhangjun zhangjun left a comment

LGTM

XiaoguangHu01 previously approved these changes Sep 14, 2022
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

7 participants